﻿using IOP.Models.Query;
using Microsoft.EntityFrameworkCore;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
using Dapper;
using System.Data;

namespace IOP.Extension.Entity
{
    /// <summary>
    /// Entity Framework Core <see cref="DbContext"/> extensions: create / update / delete helpers,
    /// raw Dapper queries over the context's connection, and a dynamic Where-expression builder.
    /// </summary>
    /// <remarks>
    /// Most mutating helpers set the context-wide <c>ChangeTracker.QueryTrackingBehavior</c> to
    /// <c>QueryTrackingBehavior.NoTracking</c> as a side effect; this persists on the passed context
    /// and affects later queries issued through it.
    /// </remarks>
    public static class EntityExtensions
    {
        /// <summary>
        /// Key property / column name used when the caller does not supply one.
        /// </summary>
        private const string DefaultIdName = "Id";

        /// <summary>
        /// Adds an entity to the database and saves immediately.
        /// </summary>
        /// <typeparam name="TModel">Entity type.</typeparam>
        /// <param name="db">Database context.</param>
        /// <param name="model">Entity to insert.</param>
        /// <returns>The same <paramref name="model"/> instance after <c>SaveChanges</c> has run.</returns>
        public static TModel Create<TModel>(this DbContext db, TModel model)
            where TModel : class
        {
            db.ChangeTracker.QueryTrackingBehavior = QueryTrackingBehavior.NoTracking;
            db.Set<TModel>().Add(model);
            db.SaveChanges();
            return model;
        }

        /// <summary>
        /// Asynchronously adds an entity to the database and saves immediately.
        /// </summary>
        /// <typeparam name="TModel">Entity type.</typeparam>
        /// <param name="db">Database context.</param>
        /// <param name="model">Entity to insert.</param>
        /// <returns>The same <paramref name="model"/> instance after the save has completed.</returns>
        public static async Task<TModel> CreateAsync<TModel>(this DbContext db, TModel model)
            where TModel : class
        {
            db.ChangeTracker.QueryTrackingBehavior = QueryTrackingBehavior.NoTracking;
            await db.Set<TModel>().AddAsync(model).ConfigureAwait(false);
            // Save asynchronously instead of blocking the caller on the synchronous SaveChanges.
            await db.SaveChangesAsync().ConfigureAwait(false);
            return model;
        }

        /// <summary>
        /// Adds several entities inside a single explicit transaction.
        /// </summary>
        /// <typeparam name="TModel">Entity type.</typeparam>
        /// <param name="db">Database context.</param>
        /// <param name="models">Entities to insert.</param>
        /// <returns>Always <c>true</c>; failures are reported by the rethrown exception.</returns>
        public static bool CreateRange<TModel>(this DbContext db, IEnumerable<TModel> models)
            where TModel : class
        {
            ExecuteInTransaction(db, () =>
            {
                db.ChangeTracker.QueryTrackingBehavior = QueryTrackingBehavior.NoTracking;
                db.Set<TModel>().AddRange(models);
                db.SaveChanges();
            });
            return true;
        }

        /// <summary>
        /// Inserts or updates an entity graph. Entries whose key is set are marked Modified;
        /// the others are marked Added, with the key property flagged as temporary.
        /// </summary>
        /// <typeparam name="TModel">Entity type.</typeparam>
        /// <param name="db">Database context.</param>
        /// <param name="model">Root entity of the graph to save.</param>
        /// <param name="idName">Key property name; empty means <c>"Id"</c>.</param>
        /// <returns>The same <paramref name="model"/> instance after <c>SaveChanges</c> has run.</returns>
        public static TModel CreateOrUpdate<TModel>(this DbContext db, TModel model, string idName = "")
            where TModel : class
        {
            db.ChangeTracker.QueryTrackingBehavior = QueryTrackingBehavior.NoTracking;
            TrackGraphByKeyState(db, model, ResolveIdName(idName));
            db.SaveChanges();
            return model;
        }

        /// <summary>
        /// Asynchronously inserts or updates an entity graph. Entries whose key is set are marked
        /// Modified; the others are marked Added, with the key property flagged as temporary.
        /// </summary>
        /// <typeparam name="TModel">Entity type.</typeparam>
        /// <param name="db">Database context.</param>
        /// <param name="model">Root entity of the graph to save.</param>
        /// <param name="idName">Key property name; empty means <c>"Id"</c>.</param>
        /// <returns>The same <paramref name="model"/> instance after the save has completed.</returns>
        public static async Task<TModel> CreateOrUpdateAsync<TModel>(this DbContext db, TModel model, string idName = "")
        {
            var id = ResolveIdName(idName);
            // Graph tracking is in-memory work; doing it inline (rather than on a thread-pool task)
            // keeps all DbContext access on the caller's logical flow, since DbContext is not thread-safe.
            db.ChangeTracker.QueryTrackingBehavior = QueryTrackingBehavior.NoTracking;
            TrackGraphByKeyState(db, model, id);
            await db.SaveChangesAsync().ConfigureAwait(false);
            return model;
        }

        /// <summary>
        /// Inserts or updates several entity graphs inside a single explicit transaction.
        /// An entry is inserted when its integer key is &lt;= 0, or, for other key types, when
        /// EF reports the key as not set. Otherwise it is updated.
        /// </summary>
        /// <typeparam name="TModel">Entity type.</typeparam>
        /// <param name="db">Database context.</param>
        /// <param name="models">Root entities to save.</param>
        /// <param name="idName">Key property name; empty means <c>"Id"</c>.</param>
        /// <returns>Always <c>true</c>; failures are reported by the rethrown exception.</returns>
        public static bool CreateOrUpdateRange<TModel>(this DbContext db, IEnumerable<TModel> models, string idName = "")
            where TModel : class
        {
            var id = ResolveIdName(idName);
            ExecuteInTransaction(db, () =>
            {
                db.ChangeTracker.QueryTrackingBehavior = QueryTrackingBehavior.NoTracking;
                foreach (var item in models)
                {
                    db.ChangeTracker.TrackGraph(item, node =>
                    {
                        var entry = node.Entry;
                        // NOTE(review): every node in the graph (including related entities) must
                        // expose a property named `id`, otherwise EF throws here.
                        var key = entry.Property(id);
                        if (IsUnsavedKey(key.CurrentValue, entry.IsKeySet))
                        {
                            entry.State = EntityState.Added;
                            key.IsTemporary = true;
                        }
                        else
                        {
                            entry.State = EntityState.Modified;
                        }
                    });
                }
                db.SaveChanges();
            });
            return true;
        }

        /// <summary>
        /// Deletes several entities inside a single explicit transaction.
        /// </summary>
        /// <typeparam name="TModel">Entity type.</typeparam>
        /// <param name="db">Database context.</param>
        /// <param name="models">Entities to delete.</param>
        /// <returns>Always <c>true</c>; failures are reported by the rethrown exception.</returns>
        public static bool DeleteRange<TModel>(this DbContext db, IEnumerable<TModel> models)
            where TModel : class
        {
            ExecuteInTransaction(db, () =>
            {
                db.ChangeTracker.QueryTrackingBehavior = QueryTrackingBehavior.NoTracking;
                db.RemoveRange(models);
                db.SaveChanges();
            });
            return true;
        }

        /// <summary>
        /// Deletes the rows of <paramref name="tableName"/> whose key column is in <paramref name="ids"/>,
        /// using raw SQL via Dapper on the context's connection.
        /// </summary>
        /// <typeparam name="T">Key value type.</typeparam>
        /// <param name="db">Database context whose connection is used.</param>
        /// <param name="tableName">Table name. It is interpolated into the SQL text unescaped, so never pass untrusted input.</param>
        /// <param name="ids">Key values to delete; sent as a parameter.</param>
        /// <param name="idName">Key column name, also interpolated unescaped; empty means <c>"Id"</c>.</param>
        /// <returns>Always <c>true</c>; the affected-row count is not inspected.</returns>
        public static async Task<bool> DeleteRangeAsync<T>(this DbContext db, string tableName, IEnumerable<T> ids, string idName = "")
            where T : IComparable<T>
        {
            var column = ResolveIdName(idName);
            var conn = db.Database.GetDbConnection();
            await conn.ExecuteAsync($"DELETE FROM {tableName} WHERE {column} IN @ids", new { ids }).ConfigureAwait(false);
            return true;
        }

        /// <summary>
        /// Deletes the row of <paramref name="tableName"/> whose key column equals <paramref name="id"/>,
        /// using raw SQL via Dapper on the context's connection.
        /// </summary>
        /// <typeparam name="T">Key value type.</typeparam>
        /// <param name="db">Database context whose connection is used.</param>
        /// <param name="tableName">Table name. It is interpolated into the SQL text unescaped, so never pass untrusted input.</param>
        /// <param name="id">Key value to delete; sent as a parameter.</param>
        /// <param name="idName">Key column name, also interpolated unescaped; empty means <c>"Id"</c>.</param>
        /// <returns>Always <c>true</c>; the affected-row count is not inspected.</returns>
        public static async Task<bool> DeleteAsync<T>(this DbContext db, string tableName, T id, string idName = "")
            where T : IComparable<T>
        {
            var column = ResolveIdName(idName);
            var conn = db.Database.GetDbConnection();
            await conn.ExecuteAsync($"DELETE FROM {tableName} WHERE {column} = @id", new { id }).ConfigureAwait(false);
            return true;
        }

        /// <summary>
        /// Runs a raw SQL query through Dapper on the context's connection.
        /// </summary>
        /// <typeparam name="TDto">Row type to map results to.</typeparam>
        /// <param name="db">Database context whose connection is used.</param>
        /// <param name="sql">SQL text.</param>
        /// <param name="parameters">
        /// Parameters: an <c>IDictionary&lt;string, object&gt;</c>, a <see cref="DynamicParameters"/>,
        /// an object whose public properties are the parameters, or <c>null</c>.
        /// </param>
        /// <param name="commandType">
        /// Command type. For <see cref="CommandType.StoredProcedure"/> with non-dictionary parameters,
        /// the parameter placeholders are appended to <paramref name="sql"/> and the result is sent as
        /// command text.
        /// </param>
        /// <returns>The mapped rows.</returns>
        public static async Task<IEnumerable<TDto>> QueryAsync<TDto>(this DbContext db, string sql, object parameters = null, CommandType? commandType = null)
            where TDto : class
        {
            var conn = db.Database.GetDbConnection();

            var dictionary = parameters as IDictionary<string, object>;
            if (dictionary != null)
            {
                var dynamicParameters = new DynamicParameters();
                foreach (var pair in dictionary)
                {
                    dynamicParameters.Add(pair.Key, pair.Value);
                }
                return await conn.QueryAsync<TDto>(sql, dynamicParameters, commandType: commandType).ConfigureAwait(false);
            }

            if (commandType == CommandType.StoredProcedure)
            {
                // Appends " @A , @B " to sql and sends it without a command type (plain text).
                // NOTE(review): this presumably expects sql like "EXEC ProcName"; confirm against callers.
                var placeholders = GetParameterNames(parameters).Select(name => $" @{name} ");
                return await conn.QueryAsync<TDto>(sql + string.Join(",", placeholders), parameters).ConfigureAwait(false);
            }

            return await conn.QueryAsync<TDto>(sql, parameters, commandType: commandType).ConfigureAwait(false);
        }

        /// <summary>
        /// Materializes an EF query into a list, after switching the context to no-tracking queries.
        /// </summary>
        /// <typeparam name="TDbContext">Context type.</typeparam>
        /// <typeparam name="TDto">Element type of the query.</typeparam>
        /// <param name="db">Context whose tracking behavior is set to NoTracking.</param>
        /// <param name="lambda">Query to execute.</param>
        /// <returns>The query results.</returns>
        public static async Task<List<TDto>> QueryAsync<TDbContext, TDto>(this TDbContext db, IQueryable<TDto> lambda)
            where TDbContext : DbContext
        {
            db.ChangeTracker.QueryTrackingBehavior = QueryTrackingBehavior.NoTracking;
            return await lambda.ToListAsync().ConfigureAwait(false);
        }

        /// <summary>
        /// Builds a <c>Where</c> predicate from a list of filter descriptors.
        /// </summary>
        /// <typeparam name="TModel">Type the predicate is applied to.</typeparam>
        /// <param name="parameters">
        /// Filters. An entry with an empty <c>Name</c> or a null <c>Value</c> is skipped.
        /// Each later entry is combined with the accumulated predicate using its <c>Connector</c>.
        /// </param>
        /// <param name="parametersName">Name of the lambda parameter.</param>
        /// <returns>The combined predicate; an always-true predicate when no filter applies.</returns>
        /// <exception cref="NotSupportedException">An entry uses an operator this method does not handle.</exception>
        public static Expression<Func<TModel, bool>> BuildWhereLambda<TModel>(this IEnumerable<WhereParameter> parameters, string parametersName)
        {
            ParameterExpression parameter = Expression.Parameter(typeof(TModel), parametersName);
            Expression body = null;
            foreach (var item in parameters)
            {
                if (string.IsNullOrEmpty(item.Name) || item.Value == null)
                {
                    continue;
                }

                MemberExpression member = Expression.PropertyOrField(parameter, item.Name);
                ConstantExpression constant = Expression.Constant(item.Value);
                Expression condition;
                // member.Type works for fields and case-insensitive matches, both of which
                // PropertyOrField accepts; typeof(TModel).GetProperty(name) returned null for them.
                switch (item.Operator)
                {
                    case Operator.Equal:
                        condition = Compare(ExpressionType.Equal, member, constant);
                        break;
                    case Operator.GreaterThan:
                        condition = Compare(ExpressionType.GreaterThan, member, constant);
                        break;
                    case Operator.GreaterThanOrEqual:
                        condition = Compare(ExpressionType.GreaterThanOrEqual, member, constant);
                        break;
                    case Operator.IN:
                        // item.Value is the collection: Enumerable.Contains(value, x.Member).
                        condition = Expression.Call(typeof(Enumerable), "Contains", new[] { member.Type }, constant, member);
                        break;
                    case Operator.LessThan:
                        condition = Compare(ExpressionType.LessThan, member, constant);
                        break;
                    case Operator.LessThanOrEqual:
                        condition = Compare(ExpressionType.LessThanOrEqual, member, constant);
                        break;
                    case Operator.NotEqual:
                        condition = Compare(ExpressionType.NotEqual, member, constant);
                        break;
                    default:
                        throw new NotSupportedException($"Operator '{item.Operator}' is not supported for member '{item.Name}'.");
                }

                if (body == null)
                {
                    // The first applied filter's connector is ignored.
                    body = condition;
                    continue;
                }

                switch (item.Connector)
                {
                    case Connector.AND:
                        body = Expression.AndAlso(body, condition);
                        break;
                    case Connector.OR:
                        body = Expression.OrElse(body, condition);
                        break;
                    // NOTE(review): any other Connector value silently drops this condition (kept from the original).
                }
            }

            // No applicable filters: match everything instead of failing on a null lambda body.
            return Expression.Lambda<Func<TModel, bool>>(body ?? Expression.Constant(true), parameter);
        }

        /// <summary>
        /// Returns <paramref name="idName"/>, or <c>"Id"</c> when it is null or empty.
        /// </summary>
        private static string ResolveIdName(string idName)
        {
            return string.IsNullOrEmpty(idName) ? DefaultIdName : idName;
        }

        /// <summary>
        /// Runs <paramref name="work"/> inside an explicit transaction on <paramref name="db"/>.
        /// Commits on success; rolls back and rethrows (preserving the stack trace) on failure.
        /// </summary>
        private static void ExecuteInTransaction(DbContext db, Action work)
        {
            using (var trans = db.Database.BeginTransaction())
            {
                try
                {
                    work();
                    trans.Commit();
                }
                catch (Exception)
                {
                    trans.Rollback();
                    throw;
                }
            }
        }

        /// <summary>
        /// Tracks the graph rooted at <paramref name="model"/>. Entries whose key is set become Modified;
        /// the others become Added, with property <paramref name="id"/> flagged as temporary.
        /// </summary>
        private static void TrackGraphByKeyState(DbContext db, object model, string id)
        {
            db.ChangeTracker.TrackGraph(model, node =>
            {
                var entry = node.Entry;
                if (entry.IsKeySet)
                {
                    entry.State = EntityState.Modified;
                }
                else
                {
                    entry.State = EntityState.Added;
                    entry.Property(id).IsTemporary = true;
                }
            });
        }

        /// <summary>
        /// Decides whether an entry should be inserted. <c>int</c>/<c>long</c> keys &lt;= 0 count as unsaved;
        /// any other key type (or a null key) falls back to EF's "key is set" check.
        /// </summary>
        private static bool IsUnsavedKey(object keyValue, bool isKeySet)
        {
            if (keyValue is int)
            {
                return (int)keyValue <= 0;
            }
            if (keyValue is long)
            {
                return (long)keyValue <= 0;
            }
            return !isKeySet;
        }

        /// <summary>
        /// Lists parameter names for the stored-procedure branch of <c>QueryAsync</c>. It returns:
        /// no names for <c>null</c>; the registered names for <see cref="DynamicParameters"/>;
        /// otherwise the public property names of the object.
        /// </summary>
        private static IEnumerable<string> GetParameterNames(object parameters)
        {
            if (parameters == null)
            {
                return Enumerable.Empty<string>();
            }

            var dynamicParameters = parameters as DynamicParameters;
            if (dynamicParameters != null)
            {
                return dynamicParameters.ParameterNames;
            }

            return parameters.GetType().GetProperties().Select(property => property.Name);
        }

        /// <summary>
        /// Builds <c>member (kind) Convert(constant, member.Type)</c>.
        /// </summary>
        private static Expression Compare(ExpressionType kind, MemberExpression member, ConstantExpression constant)
        {
            return Expression.MakeBinary(kind, member, Expression.Convert(constant, member.Type));
        }
    }
}
